In this project, we analyse the determinants of song popularity using a dataset of Spotify tracks.
Our original dataset covers 114,000 tracks, each with 21 associated features, ranging from artist name, popularity, duration, and genre to audio measures such as 'acousticness' and tempo. Measures that cannot be observed directly, such as 'acousticness', 'danceability', and 'instrumentalness', have been normalised to a 0-1 scale.
We are interested in which factors affect 'popularity', and expect it to be determined by the other regressors in the data set, such as 'energy', 'danceability', and 'valence'. A model along these lines could predict which songs people will enjoy before they become popular, based on the characteristic or 'intrinsic' value of the song rather than the artist name attached to it. Hence this can help with 'song recommendation' features.
It is reasonable to assume each track is independent of the others, given that songs are usually written around new concepts. We can also assume they share the same probability distribution, since all songs are judged on the same criteria, each normalised to the same scale. Hence it is reasonable to assume they are identically distributed.
Dataset source: https://www.kaggle.com/datasets/maharshipandya/-spotify-tracks-dataset
Notes:
- track_id is unique
- track_name is not unique (keep different versions by different singers; drop duplicate versions by the same singer)
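A sketch of this de-duplication step, assuming the raw Kaggle file is named dataset.csv (the filename is an assumption; the artists and track_name columns come from the dataset above):

```r
library(dplyr)

# Keep one row per (artists, track_name) pair: covers by different singers
# survive, while repeated releases by the same singer are dropped
raw <- read.csv("dataset.csv")  # assumed filename for the raw Kaggle export
unique_tracks <- raw %>%
  distinct(artists, track_name, .keep_all = TRUE)
```

The genre dummy columns used later would be built on top of this de-duplicated table before it is written out as unique_tracks_genres.csv.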
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.2 ──
## ✔ ggplot2 3.4.0 ✔ purrr 0.3.4
## ✔ tibble 3.1.8 ✔ dplyr 1.0.10
## ✔ tidyr 1.2.1 ✔ stringr 1.4.1
## ✔ readr 2.1.2 ✔ forcats 0.5.2
## Warning: package 'ggplot2' was built under R version 4.2.2
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
library(GGally) # ggpairs()
## Registered S3 method overwritten by 'GGally':
## method from
## +.gg ggplot2
library(corrplot) # corrplot()
## Warning: package 'corrplot' was built under R version 4.2.2
## corrplot 0.92 loaded
library(gridExtra) # grid.arrange()
##
## Attaching package: 'gridExtra'
##
## The following object is masked from 'package:dplyr':
##
## combine
library(ggplot2) # gm_scatterplot
library(tidymodels) # initial_split()
## Warning: package 'tidymodels' was built under R version 4.2.2
## ── Attaching packages ────────────────────────────────────── tidymodels 1.0.0 ──
## ✔ broom 1.0.1 ✔ rsample 1.1.0
## ✔ dials 1.1.0 ✔ tune 1.0.1
## ✔ infer 1.0.3 ✔ workflows 1.1.2
## ✔ modeldata 1.0.1 ✔ workflowsets 1.0.0
## ✔ parsnip 1.0.3 ✔ yardstick 1.1.0
## ✔ recipes 1.0.3
## Warning: package 'dials' was built under R version 4.2.2
## Warning: package 'infer' was built under R version 4.2.2
## Warning: package 'modeldata' was built under R version 4.2.2
## Warning: package 'parsnip' was built under R version 4.2.2
## Warning: package 'recipes' was built under R version 4.2.2
## Warning: package 'rsample' was built under R version 4.2.2
## Warning: package 'tune' was built under R version 4.2.2
## Warning: package 'workflows' was built under R version 4.2.2
## Warning: package 'workflowsets' was built under R version 4.2.2
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ gridExtra::combine() masks dplyr::combine()
## ✖ scales::discard() masks purrr::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ recipes::fixed() masks stringr::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ yardstick::spec() masks readr::spec()
## ✖ recipes::step() masks stats::step()
## • Learn how to get started at https://www.tidymodels.org/start/
library(glmnet) # glmnet()
## Warning: package 'glmnet' was built under R version 4.2.2
## Loading required package: Matrix
## Warning: package 'Matrix' was built under R version 4.2.2
##
## Attaching package: 'Matrix'
##
## The following objects are masked from 'package:tidyr':
##
## expand, pack, unpack
##
## Loaded glmnet 4.1-4
data_new <- read.csv("unique_tracks_genres.csv")
final_data <- select(data_new, -X, -track_id, -artists, -album_name, -track_name, -track_genre)
# convert into factors
final_data$explicit <- as.numeric(as.factor(final_data$explicit))-1 # 0 for FALSE, 1 for TRUE
final_data$mode <- as.integer(final_data$mode) # 0 for minor, 1 for major
# create new dummy showing if track has >1 genre
genres <- c("pop", "rock", "country", "jazz", "electronic", "classical", "world", "kids", "other", "rap")
final_data$two_genre <- as.numeric(rowSums(final_data[, genres]) == 2)
final_data$three_genre <- as.numeric(rowSums(final_data[, genres]) == 3)
final_data$four_genre <- as.numeric(rowSums(final_data[, genres]) == 4)
final_data$five_genre <- as.numeric(rowSums(final_data[, genres]) == 5)
summary(final_data)
## popularity duration_ms explicit danceability
## Min. : 0.00 Min. : 0 Min. :0.00000 Min. :0.0000
## 1st Qu.: 21.00 1st Qu.: 173871 1st Qu.:0.00000 1st Qu.:0.4460
## Median : 35.00 Median : 215213 Median :0.00000 Median :0.5730
## Mean : 34.76 Mean : 231419 Mean :0.08544 Mean :0.5593
## 3rd Qu.: 49.00 3rd Qu.: 267344 3rd Qu.:0.00000 3rd Qu.:0.6900
## Max. :100.00 Max. :5237295 Max. :1.00000 Max. :0.9850
## energy key loudness mode
## Min. :0.000 Min. : 0.000 Min. :-49.531 Min. :0.0000
## 1st Qu.:0.455 1st Qu.: 2.000 1st Qu.:-10.456 1st Qu.:0.0000
## Median :0.678 Median : 5.000 Median : -7.263 Median :1.0000
## Mean :0.635 Mean : 5.285 Mean : -8.596 Mean :0.6323
## 3rd Qu.:0.857 3rd Qu.: 8.000 3rd Qu.: -5.142 3rd Qu.:1.0000
## Max. :1.000 Max. :11.000 Max. : 4.532 Max. :1.0000
## speechiness acousticness instrumentalness liveness
## Min. :0.0000 Min. :0.0000 Min. :0.0000000 Min. :0.0000
## 1st Qu.:0.0361 1st Qu.:0.0159 1st Qu.:0.0000000 1st Qu.:0.0985
## Median :0.0491 Median :0.1900 Median :0.0000886 Median :0.1330
## Mean :0.0890 Mean :0.3297 Mean :0.1847155 Mean :0.2197
## 3rd Qu.:0.0870 3rd Qu.:0.6290 3rd Qu.:0.1530000 3rd Qu.:0.2830
## Max. :0.9650 Max. :0.9960 Max. :1.0000000 Max. :1.0000
## valence tempo time_signature pop
## Min. :0.0000 Min. : 0.0 Min. :0.000 Min. :0.0000
## 1st Qu.:0.2410 1st Qu.: 99.4 1st Qu.:4.000 1st Qu.:0.0000
## Median :0.4490 Median :122.0 Median :4.000 Median :0.0000
## Mean :0.4633 Mean :122.1 Mean :3.897 Mean :0.1571
## 3rd Qu.:0.6760 3rd Qu.:140.1 3rd Qu.:4.000 3rd Qu.:0.0000
## Max. :0.9950 Max. :243.4 Max. :5.000 Max. :1.0000
## rock country jazz electronic
## Min. :0.0000 Min. :0.00000 Min. :0.00000 Min. :0.000
## 1st Qu.:0.0000 1st Qu.:0.00000 1st Qu.:0.00000 1st Qu.:0.000
## Median :0.0000 Median :0.00000 Median :0.00000 Median :0.000
## Mean :0.1967 Mean :0.06841 Mean :0.08189 Mean :0.252
## 3rd Qu.:0.0000 3rd Qu.:0.00000 3rd Qu.:0.00000 3rd Qu.:1.000
## Max. :1.0000 Max. :1.00000 Max. :1.00000 Max. :1.000
## classical world kids other
## Min. :0.000 Min. :0.000 Min. :0.00000 Min. :0.0000
## 1st Qu.:0.000 1st Qu.:0.000 1st Qu.:0.00000 1st Qu.:0.0000
## Median :0.000 Median :0.000 Median :0.00000 Median :0.0000
## Mean :0.065 Mean :0.304 Mean :0.03445 Mean :0.1343
## 3rd Qu.:0.000 3rd Qu.:1.000 3rd Qu.:0.00000 3rd Qu.:0.0000
## Max. :1.000 Max. :1.000 Max. :1.00000 Max. :1.0000
## rap two_genre three_genre four_genre
## Min. :0.00000 Min. :0.0000 Min. :0.00000 Min. :0.000000
## 1st Qu.:0.00000 1st Qu.:0.0000 1st Qu.:0.00000 1st Qu.:0.000000
## Median :0.00000 Median :0.0000 Median :0.00000 Median :0.000000
## Mean :0.03698 Mean :0.2338 Mean :0.02991 Mean :0.009724
## 3rd Qu.:0.00000 3rd Qu.:0.0000 3rd Qu.:0.00000 3rd Qu.:0.000000
## Max. :1.00000 Max. :1.0000 Max. :1.00000 Max. :1.000000
## five_genre
## Min. :0.000000
## 1st Qu.:0.000000
## Median :0.000000
## Mean :0.001672
## 3rd Qu.:0.000000
## Max. :1.000000
col_names <- names(final_data)
for (i in seq_along(col_names)) {
  hist(final_data[, i], main = paste("Histogram of", col_names[[i]]))
}
final_data_cor1 <- cor(final_data)
corrplot(final_data_cor1, method="square", col = rev(colorRampPalette(c("#B40F20", "#FFFFFF", "#2E3A87"))(100)), type="lower", tl.col="black", tl.srt=60, tl.cex = 0.6)
The features shown below were selected for their high absolute correlation with popularity in the correlation plot.
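This selection can be sanity-checked directly from the correlation matrix computed above (a sketch using the final_data_cor1 object):

```r
# Rank all other features by absolute correlation with popularity
pop_cor <- final_data_cor1["popularity", ]
pop_cor <- pop_cor[names(pop_cor) != "popularity"]  # drop the self-correlation of 1
head(sort(abs(pop_cor), decreasing = TRUE))         # strongest candidates first
```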
ggpairs(final_data, columns = c("popularity", "danceability", "loudness", "instrumentalness"), lower = list(continuous = "smooth"), upper = list(continuous = "cor"))
Below, popularity is plotted against each of: danceability, energy, speechiness, acousticness, instrumentalness, liveness, and valence.
basic_plots <- function(x) {
  # scatterplot with point transparency to show overplotting
  plot_nt <- ggplot(final_data, aes(x = !!sym(x), y = popularity)) +
    geom_point(alpha = 0.1)
  # 2-D binned count plot of the same relationship
  plot_wt <- ggplot(final_data, aes(x = !!sym(x), y = popularity)) +
    geom_bin2d(alpha = 0.7) +
    scale_fill_gradientn(colors = c("#440154", "#30678D", "#35B778", "#FDE724", "#FFFFFF"))
  # return both plots
  return(list(plot_nt, plot_wt))
}
metrics <- c('danceability', 'energy', 'speechiness', 'acousticness', 'instrumentalness', 'liveness', 'valence')
for (i in metrics) {
  plots <- basic_plots(i)
  grid.arrange(plots[[1]], plots[[2]], ncol = 2)
}
# Assign the genre name based on the dummy variables
get_genre_name <- function(x) {
ifelse(x["two_genre"] == 1, "2_genres",
ifelse(x["three_genre"] == 1, "3_genres",
ifelse(x["four_genre"] == 1, "4_genres",
ifelse(x["five_genre"] == 1, "5_genres",
ifelse(x["rock"] == 1, "rock",
ifelse(x["country"] == 1, "country",
ifelse(x["jazz"] == 1, "jazz",
ifelse(x["electronic"] == 1, "electronic",
ifelse(x["classical"] == 1, "classical",
ifelse(x["world"] == 1, "world",
ifelse(x["kids"] == 1, "kids",
ifelse(x["other"] == 1, "other",
ifelse(x["rap"] == 1, "rap", "pop")))))))))))))
}
# Apply the function to each row of the data frame and create a new column with the genre names
temp_data <- data.frame(final_data)
temp_data$genre_name <- apply(final_data[, -1], 1, get_genre_name)
# Create a bar plot of mean popularity by genre
mean_popularity <- tapply(temp_data$popularity, temp_data$genre_name, mean)
barplot(mean_popularity, xlab = "Genre", ylab = "Mean Popularity", col = "steelblue", main = "Mean Popularity by Genre", las = 2, cex.names = 0.8)
# Define X, y, data
X <- select(final_data, -1)
y <- final_data$popularity
data <- data.frame(y = y, X = X)
# Split data into training and test set
data_split <- initial_split(data)
data_train <- training(data_split)
data_test <- testing(data_split)
# Cross-validation for tuning the parameters
data_cv <- vfold_cv(data_train, v = 10)
# Pre-process the model
data_recipe <- data_train %>%
recipe(y ~ .) %>%
prep()
A simple baseline for comparison with the more sophisticated models; here we have chosen linear regression on a small subset of the features.
baseline <- lm(y ~ X.explicit + X.danceability + X.instrumentalness, data = data_train)
predictions_baseline <- predict(baseline, newdata = data_test)
# Test metrics ------------------
RMSE_baseline <- sqrt(mean((data_test$y - predictions_baseline)^2))
RSQ_baseline <- cor(data_test$y, predictions_baseline)^2
# Print the value
print("Testing: ")
## [1] "Testing: "
cat("RMSE:", RMSE_baseline, "\n")
## RMSE: 18.89782
cat("R-squared:", RSQ_baseline, "\n")
## R-squared: 0.03682725
# Training metrics
# Get summary statistics ------------------
summary_stats <- summary(baseline)
# Extract RMSE and R-squared values
RMSE_baseline_train <- sqrt(mean(summary_stats$residuals^2))
RSQ_baseline_train <- summary_stats$r.squared
print("Training: ")
## [1] "Training: "
cat("RMSE:", RMSE_baseline_train, "\n")
## RMSE: 19.00617
cat("R-squared:", RSQ_baseline_train, "\n")
## R-squared: 0.03649003
Non-baseline model that is (relatively) interpretable: penalised linear regression.
- linear_reg() function from the parsnip package specifies the model
- tune() is used to mark the hyperparameters penalty (P) and mixture (M) for tuning
- set_engine() specifies the modelling engine used to fit the model (here we use glmnet)
- pen_reg_y is a model specification object that can be further used for model training, tuning and prediction
# Model specification = penalised linear regression
pen_reg_y <- linear_reg(penalty = tune('P'), mixture = tune('M')) %>%
set_engine('glmnet')
# Set up the workflow
pen_reg_wf <- workflow() %>%
add_recipe(data_recipe) %>%
add_model(pen_reg_y)
# Tune the parameters
fit_pen_reg <- tune_grid(pen_reg_wf,
#grid = data.frame(P = 2^seq(-3, 2, by = 1),
#M = seq(0, 1, by = 0.2)),
data_cv,
metrics = metric_set(rmse, mae, rsq),
control = control_grid(save_pred = TRUE))
fit_pen_reg %>% autoplot() # plot the result for each value of the parameters
# Select the best model with the smallest cross-validation rmse
pen_reg_best <- fit_pen_reg %>%
select_best(metric = 'rmse')
pen_reg_best # print the best model
## # A tibble: 1 × 3
## P M .config
## <dbl> <dbl> <chr>
## 1 0.0199 0.710 Preprocessor1_Model07
### With the best parameters in hand, we can now finalise and refit the model
# Fit the final model
pen_reg_final <- finalize_model(pen_reg_y, pen_reg_best)
# Predict on the test data with the final model
pen_reg_test <- pen_reg_wf %>%
update_model(pen_reg_final) %>%
last_fit(split = data_split) %>%
collect_metrics()
pen_reg_test # print the result
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 18.1 Preprocessor1_Model1
## 2 rsq standard 0.115 Preprocessor1_Model1
# Extract the tuned hyperparameters as plain numbers
P_best <- pen_reg_best$P
M_best <- pen_reg_best$M
# glmnet() expects a numeric matrix of predictors, not a data frame
X_train <- as.matrix(select(data_train, -1))
glmnet_best <- glmnet(X_train, data_train$y,
                      family = "gaussian",
                      alpha = M_best)
glmnet_lasso <- glmnet(X_train, data_train$y,
                       family = "gaussian",
                       alpha = 1) # pure lasso
glmnet_ridge <- glmnet(X_train, data_train$y,
                       family = "gaussian",
                       alpha = 0) # pure ridge
plot(glmnet_best, xvar = "lambda")
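To see which predictors the tuned elastic net actually retains, the coefficient vector can be read off the fitted path at the tuned penalty (a sketch; s is glmnet's name for the penalty value at which to evaluate the path):

```r
# Coefficients of the tuned fit at the cross-validated penalty; predictors
# shrunk exactly to zero have been dropped by the lasso component of the penalty
coef(glmnet_best, s = pen_reg_best$P)
```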
Interpretation:
Comparison to the baseline model: predictive accuracy is better, as seen in the lower test RMSE (18.1 vs 18.9), and the R-squared has improved from 0.03 to 0.11. Both are nevertheless still very low, which gives us reason to think the relationship may not be linear. This motivates one of the complex, non-linear models that followed: the random forest.
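A minimal sketch of that next step, using the randomForest package (the package choice and hyperparameter values here are assumptions, not the fitted model from the report; ranger would be a faster drop-in for a dataset of this size):

```r
library(randomForest)

# Fit a random forest on the same training split; ntree is kept modest here
# and mtry is left at its regression default (p/3), both candidates for tuning
set.seed(123)
rf_fit <- randomForest(y ~ ., data = data_train, ntree = 200)

# Evaluate on the held-out test set, mirroring the baseline metrics above
rf_pred <- predict(rf_fit, newdata = data_test)
cat("RMSE:", sqrt(mean((data_test$y - rf_pred)^2)), "\n")
cat("R-squared:", cor(data_test$y, rf_pred)^2, "\n")
```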